{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import adaboost\n",
    "import utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 一个简单的人造数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "datMat =  [[1.  2.1]\n",
      " [2.  1.1]\n",
      " [1.3 1. ]\n",
      " [1.  1. ]\n",
      " [2.  1. ]]\n",
      "classLabels =  [1.0, 1.0, -1.0, -1.0, 1.0]\n"
     ]
    }
   ],
   "source": [
    "datMat,classLabels=utils.loadSimpData()\n",
    "print (\"datMat = \", datMat)\n",
    "print (\"classLabels = \", classLabels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Expect Result:  \n",
    "\n",
    "datMat =  [[1.  2.1]  \n",
    " [2.  1.1]  \n",
    " [1.3 1. ]  \n",
    " [1.  1. ]  \n",
    " [2.  1. ]]  \n",
    "classLabels =  [1.0, 1.0, -1.0, -1.0, 1.0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bestStump =  {'feature': 0, 'value': 1.3, 'ineq': 'lt'}\n",
      "bestClasEst =  [-1.  1. -1. -1.  1.]\n",
      "minError =  0.2\n"
     ]
    }
   ],
   "source": [
    "D = np.ones(5)/5   # 平均权重\n",
    "bestStump, bestClasEst, minError = adaboost.buildStump(datMat,classLabels,D)\n",
    "print (\"bestStump = \", bestStump)\n",
    "print (\"bestClasEst = \", bestClasEst)\n",
    "print (\"minError = \", minError)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Expect Result:  \n",
    "\n",
    "bestStump =  {'feature': 0, 'value': 1.3, 'ineq': 'lt'}  \n",
    "bestClasEst =  [-1.  1. -1. -1.  1.]  \n",
    "minError =  0.2  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "weekClassifier =  [{'feature': 0, 'value': 1.3, 'ineq': 'lt', 'alpha': 0.6931471805599453}, {'feature': 1, 'value': 1.0, 'ineq': 'lt', 'alpha': 0.9729550745276565}, {'feature': 0, 'value': 0.9, 'ineq': 'lt', 'alpha': 0.8958797346140273}]\n",
      "aggPredict =  [ 1.17568763  2.56198199 -0.77022252 -0.77022252  0.61607184]\n"
     ]
    }
   ],
   "source": [
    "datArr,labelArr = utils.loadSimpData()\n",
    "weekClassifier, aggPredict = adaboost.adaBoostTrainDS(datMat,classLabels,9)\n",
    "print (\"weekClassifier = \", weekClassifier)\n",
    "print (\"aggPredict = \", aggPredict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Expect Result:  \n",
    "\n",
    "D= [0.2 0.2 0.2 0.2 0.2]  \n",
    "predict:  [-1.  1. -1. -1.  1.]  \n",
    "aggpredict= [-0.69314718  0.69314718 -0.69314718 -0.69314718  0.69314718]  \n",
    "total error:  0.2  \n",
    "D= [0.5   0.125 0.125 0.125 0.125]  \n",
    "predict:  [ 1.  1. -1. -1. -1.]  \n",
    "aggpredict= [ 0.27980789  1.66610226 -1.66610226 -1.66610226 -0.27980789]  \n",
    "total error:  0.2  \n",
    "D= [0.28571429 0.07142857 0.07142857 0.07142857 0.5       ]  \n",
    "predict:  [1. 1. 1. 1. 1.]  \n",
    "aggpredict= [ 1.17568763  2.56198199 -0.77022252 -0.77022252  0.61607184]  \n",
    "total error:  0.0  \n",
    "weekClassifier =  [{'feature': 0, 'value': 1.3, 'ineq': 'lt', 'alpha': 0.6931471805599453}, {'feature': 1, 'value': 1.0, 'ineq': 'lt', 'alpha': 0.9729550745276565}, {'feature': 0, 'value': 0.9, 'ineq': 'lt', 'alpha': 0.8958797346140273}]  \n",
    "aggPredict =  [ 1.17568763  2.56198199 -0.77022252 -0.77022252  0.61607184]  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 1., -1.])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adaboost.adaClassify([[5, 5],[0,0]],weekClassifier)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[ 2.56198199 -2.56198199]  \n",
    "array([ 1., -1.])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Horse Colic Data Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy =  0.7611940298507462\n"
     ]
    }
   ],
   "source": [
    "datArr,labelArr = utils.loadDataSet('horseColicTraining2.txt')\n",
    "classifierArray, _ = adaboost.adaBoostTrainDS(datArr,labelArr,10)\n",
    "testArr,testLabelArr = utils.loadDataSet('horseColicTest2.txt')\n",
    "prediction10 = adaboost.adaClassify(testArr,classifierArray)\n",
    "accuracy = utils.calculateAccuray(prediction10, testLabelArr)\n",
    "print (\"accuracy = \", accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Expect Result:  \n",
    "accuracy =  0.7611940298507462"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "参考：https://windmising.gitbook.io/liu-yu-bo-play-with-machine-learning/10-1/10-7"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ROC曲线"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "datArr,labelArr = utils.loadDataSet('horseColicTraining2.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifierArray,aggClassEst = adaboost.adaBoostTrainDS(datArr,labelArr,10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAQcElEQVR4nO3df4hlZ33H8fcnSVMpTbTpRoj7w93SDTiGksiQNQg1oi2bgNl/rN2E0FqCq7axUKWQYokS/2kVKxW21W0bokKM0T/MIiuB2ogi7jYjptGspGwTTcaEZjUx/UM0Br/9Y+7IdfbO3DMz98fcZ94vWLjnnOfe+T57Zz958pznnJOqQpI0+86bdgGSpNEw0CWpEQa6JDXCQJekRhjoktSIC6b1g3fs2FF79+6d1o+XpJn0zW9+84dVdemgY1ML9L1797KwsDCtHy9JMynJ91c75pSLJDXCQJekRhjoktQIA12SGmGgS1IjhgZ6kjuTPJPkO6scT5KPJTmT5OEkrxl9mZKkYbqM0O8CDq5x/Dpgf+/PEeCfN1+WJGm9hq5Dr6qvJtm7RpNDwKdq6T68J5O8LMllVfX0iGqUpA27+9QT3PfQD6Zdxq+Ye8XFvP/Nrx75547iwqKdwJN924u9fecEepIjLI3i2bNnzwh+tKTWjDqATz3+LAAH9l0yss/cqkYR6Bmwb+BTM6rqGHAMYH5+3idrSI3aTCiPOoAP7LuEQ1fu5KYD7Q8iRxHoi8Duvu1dwFMj+FxJM+q+h37A6af/j7nLLl73e7dTAI/aKAL9OHBrknuAA8Dzzp9Ls28zo+zlMP/sO64ZcVVay9BAT/IZ4FpgR5JF4P3ArwFU1ceBE8D1wBngJ8CfjatYSRu33oDezNTH3GUXc+jKnet+nzanyyqXG4ccL+AvRlaRpM7WE9LrDWinPmbP1G6fK2njloN8PSFtQLfPQJdm0PJJR0Na/Qx0aUb0T6940lGDGOjSFjRobrx/esWTjhrEQJe2oEHruJ1e0TAGurRFOaWi9fJ+6JLUCEfo0gR1XTe+0cvmtb05QpcmaHlufBhPemojHKFLE7A8Mne5ocbJQJfGpH96pX/JoSNvjYuBLvWM88EKLjnUJBjoUs9m7uE9iCGuSTPQ1YRRjK6d39asM9A1E4YF9igeW+bKEs06A11bxlqhPSywnd6QDHRtIWvNYRvY0nAGurYU57CljTPQNVWD7vEtaWMMdE3cahfceFJS2hwDXRMz6DmYzo1Lo2Oga6zWuvzdEJdGy0DXWPWvXDHIpfEy0DV2rlyRJsNA16atdUGQK1ekyTHQtSGrzY2v5MoVaXIMdG2Ic+PS1mOga1188o60dRnoArrfftYn70hbl4G+jXWdB+/n9Iq0dRno25jz4FJbOgV6koPAPwLnA/9aVX+34vge4JPAy3ptbquqEyOuVZswaErFeXCpLecNa5DkfOAocB0wB9yYZG5Fs78F7q2qq4DDwD+NulBtzvJovJ9LCqW2dBmhXw2cqarHAJLcAxwCTve1KWD56pGXAk+NskiNhqNxqW1DR+jATuDJvu3F3r5+HwBuTrIInADePeiDkhxJspBk4ezZsxsoV5K0mi4j9AzYVyu2bwTuqqqPJLkG+HSSK6rqF7/ypqpjwDGA+fn5lZ+hTfISfGl76zJCXwR2923v4twplVuAewGq6hvAS4AdoyhQ3Q2aJ1/mfLnUvi4j9AeB/Un2AT9g6aTnTSvaPAG8EbgryatYCnTnVKbAeXJp+xoa6FX1YpJbgftZWpJ4Z1U9kuQOYKGqjgPvBf4lyV+xNB3ztqpySmUCfCanpGWd1qH31pSfWLHv9r7Xp4HXjbY0ddF/cZDTKtL25pWiDXCaRRIY6DPB1SuSuuiyykVT5uoVSV04Qp8RTqtIGsYRuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRXlg0RWtd0t/Py/sldWGgT1h/iJ96/FkADuy7ZM33eHm/pC4M9Anrv93tgX2XcOjKndx0YM+0y5LUAAN9Crwvi6Rx8KSoJDXCQJekRhjoktQI59AnZHl1i0sQJY2LgT4iw9aU9y9RdAmipHEw0DdpOciHrSl3iaKkcTPQO1ptBL5y5G1gS5oWA72j1ea/DXJJW4WBvg5eECRpKzPQBxg0veLqFElbnevQB1ieXunnDbIkbXWO0Ffh9IqkWeMIXZIaYaBLUiM6BXqSg0keTXImyW2rtHlrktNJHkly92jLlCQNM3QOPcn5wFHgD4BF4MEkx6vqdF+b/cDfAK+rqueSvHxcBUuSBusyQr8aOFNVj1XVC8A9wKEVbd4OHK2q5wCq6pnRljkZd596gj/+xDfOWeEiSbOgS6DvBJ7s217s7et3OXB5kq8nOZnk4KAPSnIkyUKShbNnz26s4jHqvxrUJYqSZk2XZYsZsK8GfM5+4FpgF/C1JFdU1Y9/5U1Vx4BjAPPz8ys/Y0twuaKkWdVlhL4I7O7b3gU8NaDNfVX186p6HHiUpYCXJE1Il0B/ENifZF+SC4HDwPEVbb4AvAEgyQ6WpmAeG2WhkqS1DQ30qnoRuBW4H/gucG9VPZLkjiQ39JrdD/woyWngAeCvq+pH4ypaknSuTpf+V9UJ4MSKfbf3vS7gPb0/M6X/RlzegEvSLNv2V4r234jL1S2SZpk358KVLZLasO1H6JLUCgNdkhqxLaZcVnvAM3giVFI7tsUIfdATiJZ5IlRSK5odoQ9ajuiJT0kta3aE7nJESdtNsyN0cDmipO2l2RG6JG03BrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEY090zRu0898csHRM9ddvG0y5GkiWluhN4f5oeu3DntciRpYpoboQPMXXYxn33HNdMuQ5ImqtMIPcnBJI8mOZPktjXavSVJJZkfXYmSpC6GBnqS84GjwHXAHHBjkrkB7S4C/hI4NeoiJUnDdRmhXw2cqarHquoF4B7g0IB2HwQ+BPx0hPVJkjrqEug7gSf7thd7+34pyVXA7qr64loflORIkoUkC2fPnl13sZKk1XUJ9AzYV788mJwHfBR477APqqpjVTVfVfOXXnpp9yolSUN1CfRFYHff9i7gqb7ti4ArgK8k+R7wWuC4J0YlabK6BPqDwP4k+5JcCBwGji8frKrnq2pHVe2tqr3ASeCGqloYS8WSpIGGBnpVvQjcCtwPfBe4t6oeSXJHkhvGXaAkqZtOFxZV1QngxIp9t6/S9trNlyVJWq/mLv2XpO3KQJekRjRxL5flOywC3mVR0rY104G+HOSnHn8WgAP7LvEui5K2rZkO9OVb5R7YdwmHrtzJTQf2TLskSZqamQ508Fa5krTMk6KS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpETO3bNGrQiVpsJkboS9fTAR4Vagk9Zm5ETp4MZEkDTJzI3RJ0mAGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqRKdAT3IwyaNJziS5bcDx9yQ5neThJF9O8srRlypJWsvQQE9yPnAUuA6YA25MMrei2beA+ar6PeDzwIdGXagkaW1dRuhXA2eq6rGqegG4BzjU36CqHqiqn/Q2TwK7RlumJGmYLoG+E3iyb3uxt281twBfGnQgyZEkC0kWzp49271KSdJQXQI9A/bVwIbJzcA88OFBx6vqWFXNV9X8pZde2r1KSdJQXR4SvQjs7tveBTy1slGSNwHvA15fVT8bTXmSpK66jNAfBPYn2ZfkQuAwcLy/QZKrgE8AN1TVM6MvU5I0zNBAr6oXgVuB+4HvAvdW1SNJ7khyQ6/Zh4HfBD6X5KEkx1f5OEnSmHSZcqGqTgAnVuy7ve/1m0ZclyRpnbxSVJIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRnQK9CQHkzya5EyS2wYc//Ukn+0dP5Vk76gLlSStbWigJzkfOApcB8wBNyaZW9HsFuC5qvpd4KPA34+6UEnS2rqM0K8GzlTVY1X1AnAPcGhFm0PAJ3uvPw+8MUlGV6YkaZgLOrTZCTzZt70IHFitTVW9mOR54LeBH/Y3SnIEOAKwZ8+eDRU894qLN/Q+SWpdl0AfNNKuDbShqo4BxwDm5+fPOd7F+9/86o28TZKa12XKZRHY3be9C3hqtTZJLgBeCjw7igIlSd10CfQHgf1J9iW5EDgMHF/R5jjwp73XbwH+o6o2NAKXJG3M0CmX3pz4rcD9wPnAnVX1SJI7gIWqOg78G/DpJGdYGpkfHmfRkqRzdZlDp6pOACdW7Lu97/VPgT8abWmSpPXwSlFJaoSBLkmNMNAlqREGuiQ1ItNaXZjkLPD9Db59ByuuQt0G7PP2YJ+3h830+ZVVdemgA1ML9M1IslBV89OuY5Ls8/Zgn7eHcfXZKRdJaoSBLkmNmNVAPzbtAqbAPm8P9nl7GEufZ3IOXZJ0rlkdoUuSVjDQJakRWzrQt+PDqTv0+T1JTid5OMmXk7xyGnWO0rA+97V7S5JKMvNL3Lr0Oclbe9/1I0nunnSNo9bhd3tPkgeSfKv3+339NOoclSR3JnkmyXdWOZ4kH+v9fTyc5DWb/qFVtSX/sHSr3v8Bfge4EPgvYG5Fmz8HPt57fRj47LTrnkCf3wD8Ru/1u7ZDn3vtLgK+CpwE5qdd9wS+5/3At4Df6m2/fNp1T6DPx4B39V7PAd+bdt2b7PPvA68BvrPK8euBL7H0xLfXAqc2+zO38gh9Oz6cemifq+qBqvpJb/MkS0+QmmVdvmeADwIfAn46yeLGpEuf3w4crarnAKrqmQnXOGpd+lzA8kODX8q5T0abKVX1VdZ+ctsh4FO15CTwsiSXbeZnbuVAH/Rw6p2rtamqF4Hlh1PPqi597ncLS/+Fn2VD+5zkKmB3VX1xkoWNUZfv+XLg8iRfT3IyycGJVTceXfr8AeDmJIssPX/h3ZMpbWrW++99qE4PuJiSkT2ceoZ07k+Sm4F54PVjrWj81uxzkvOAjwJvm1RBE9Dle76ApWmXa1n6v7CvJbmiqn485trGpUufbwTuqqqPJLmGpaegXVFVvxh/eVMx8vzayiP07fhw6i59JsmbgPcBN1TVzyZU27gM6/NFwBXAV5J8j6W5xuMzfmK06+/2fVX186p6HHiUpYCfVV36fAtwL0BVfQN4CUs3sWpVp3/v67GVA307Ppx6aJ970w+fYCnMZ31eFYb0uaqer6odVbW3qvaydN7ghqpamE65I9Hld/sLLJ0AJ8kOlqZgHptolaPVpc9PAG8ESPIqlgL97ESrnKzjwJ/0Vru8Fni+qp7e1CdO+0zwkLPE1wP/zdLZ8ff19t3B0j9oWPrCPwecAf4T+J1p1zyBPv878L/AQ70/x6dd87j7vKLtV5jxVS4dv+cA/wCcBr4NHJ52zRPo8xzwdZZWwDwE/OG0a95kfz8DPA38nKXR+C3AO4F39n3HR3t/H98exe+1l/5LUiO28pSLJGkdDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUiP8H2WLwck9TmHkAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "utils.plotROC(aggClassEst,labelArr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
